import numpy as np
import h5py as hp
import netCDF4 as nc
import re
from typing import Union
import struct

# Unit-conversion factors shared by the readers and analysis helpers below.
# NOTE(review): the names suggest Hartree -> kcal/mol (energy) and
# Hartree/Bohr or Hartree/Angstrom -> kcal/mol/nm (force); confirm against callers.
HAR_TO_KCAL = 627.510
HARBOHR_TO_KCALNM = HAR_TO_KCAL / 0.0529178
HARANG_TO_KCALNM = 6275.10

# Gas constant divided by 1000 (presumably kJ/(mol*K)) and its reciprocal.
R = 8.314 / 1000.
R_inv = 1. / R

def _centered_copy(coords):
    '''Return a floating-point copy of `coords` with the mean over axis 1 removed.'''
    centered = np.array(coords)  # np.array copies, so the caller's data is untouched
    if not np.issubdtype(centered.dtype, np.floating):
        centered = centered.astype(float)
    centered -= centered.mean(axis = 1, keepdims = True)
    return centered


def RMSD(input_1, input_2):
    '''
    Pairwise RMSD between two sets of structures after optimal rigid-body
    superposition (Kabsch algorithm).

    input_1: ndarray (N, M, A)   N structures, M points, A spatial dimensions
    input_2: ndarray (N', M, A)  N' structures with the same M and A

    Returns an ndarray (N, N') whose entry [i, q] is the RMSD between
    input_1[i] (translated and rotated onto input_2[q]) and input_2[q].

    The inputs are not modified. Only proper rotations are used: when the
    unconstrained optimum would be a reflection, the direction belonging to the
    smallest singular value is flipped, so mirror images do not superimpose.
    '''
    # Translate both sets so that every structure has its centroid at the origin.
    coords_1 = _centered_copy(input_1)
    coords_2 = _centered_copy(input_2)

    # Covariance of every pair: (N, A, M) x (N', M, A) -> (N, N', A, A)
    covar = np.einsum('ikj, qkl -> iqjl', coords_1, coords_2)
    U, _, Vt = np.linalg.svd(covar)

    # det(U @ Vt) == -1 means U @ Vt is an improper rotation (contains a
    # reflection). Negating the last column of U (smallest singular value)
    # gives the best proper rotation instead.
    handedness = np.where(np.linalg.det(U @ Vt) < 0, -1.0, 1.0)
    U[..., :, -1] *= handedness[..., None]
    rotation = U @ Vt  # (N, N', A, A), applied to row vectors from the right

    rotated_1 = np.einsum('ijk, iqkl->iqjl', coords_1, rotation)

    # Calculate RMSD
    displacement = rotated_1 - coords_2   # (N, N', M, A)

    # mean over (M, A) times A == squared displacement summed over A, averaged over M
    rmsd = np.sqrt(np.mean(displacement**2, axis = (-1, -2))*displacement.shape[-1]) # (N, N')

    return rmsd


def _format_atom_lines(atom_label, coordinate):
    '''Return one "label x y z" text line per row of `coordinate`.'''
    template = '{0}      {1:7.5f}     {2:7.5f}     {3:7.5f} \n'
    return [template.format(atom_label[idx], crd[0], crd[1], crd[2])
            for idx, crd in enumerate(coordinate)]


def input_writer(fmt: str, filepath: str, atom_label: list, coordinate, cmd = 'rhf/sto-3g', nproc = 32, mem = 64, charge = 0, spin = 1):
    '''
    Write a single structure to `filepath` as a 'gjf' or 'xyz' text file.

    fmt:        'gjf' or 'xyz'
    filepath:   destination path; an existing file is overwritten
    atom_label: sequence of atom labels, indexed in step with `coordinate`
    coordinate: iterable of rows whose first three entries are x, y, z
    cmd:        line written after the %-directives (gjf only)
    nproc, mem: values written into the %nprocshared / %mem lines (gjf only)
    charge, spin: values written on the line preceding the atoms (gjf only)

    Raises ValueError when `fmt` is not one of the supported formats.
    '''
    if fmt == 'gjf':
        header = [
            f'%nprocshared={nproc}  \n',
            f'%mem = {mem}GB   \n',
            cmd + '\n\n',
            'Blank  \n\n',
            f'{charge} {spin}  \n',
        ]
    elif fmt == 'xyz':
        header = [
            f'{len(atom_label)}\n',
            'blank  \n',
        ]
    else:
        # Previously an unknown format silently wrote nothing.
        raise ValueError('fmt should be gjf or xyz, but you set:', fmt)

    # Format before opening so a bad row cannot leave a half-written file behind.
    atom_lines = _format_atom_lines(atom_label, coordinate)

    with open(filepath, 'w') as f:
        f.writelines(header)
        f.writelines(atom_lines)
        f.write('\n\n\n\n')




def _read_lines(filename):
    '''Return all lines of the text file `filename`, newlines included.'''
    with open(filename, 'r') as f:
        return f.readlines()


def _is_number(token):
    '''Return True when `token` can be parsed by float().'''
    try:
        float(token)
    except ValueError:
        return False
    return True


def _parse_gaussian_out(lines):
    '''Extract the force table rows and the E(RB3LYP) energy from Gaussian output lines.'''
    force = []
    in_force_block = False
    count = 0
    for line in lines:
        if '(Hartrees/Bohr)' in line:
            in_force_block = True
        if in_force_block:
            count += 1
            # Rows 1-3 (counting the matched line) are skipped as table header.
            # NOTE(review): the window is fixed at rows 4..25 (22 rows) and ignores
            # atom_num; kept as-is for backward compatibility.
            if count > 3 and count < 27:
                force.append(line.split())
        if count > 24:
            break

    potential = None
    for line in lines:
        if 'E(RB3LYP)' in line:
            # The last matching line wins.
            potential = np.array(line.split()[4], dtype = np.float32)*HAR_TO_KCAL
    if potential is None:
        raise ValueError('no E(RB3LYP) line found in the Gaussian output')

    return {'potential': potential, 'force' : force}


def _parse_pyscf_out(lines, atom_num):
    '''Extract the gradient table and the final energy from pyscf output lines.'''
    # First line containing both '-----' and 'gradients'; falls back to line 0.
    header_idx = next(
        (idx for idx, line in enumerate(lines) if ('-----' in line) and ('gradients' in line)),
        0,
    )
    rows = lines[header_idx+2 : header_idx+2+atom_num]
    # Columns 2..4 of each row hold the x, y, z components.
    force = np.array([row.split()[2:5] for row in rows], dtype = float)*HARBOHR_TO_KCALNM

    # The original loop converted the last token of every line and therefore
    # only succeeded when all lines ended in a number, leaving the value of the
    # final line. Read that value directly from the last non-blank line.
    # NOTE(review): assumes the energy is the last token of the file - confirm.
    content = [line for line in lines if line.strip()]
    if not content:
        raise ValueError('pyscf output is empty, no potential energy found')
    potential = np.array(content[-1].split()[-1], dtype = float)*HAR_TO_KCAL

    return {'potential': potential, 'force' : force}


def _parse_sponge_out(lines):
    '''Parse a whitespace-separated table: first line holds names, the rest numeric rows.'''
    items = lines[0].split()
    rows = [line.split() for line in lines[1:] if line.strip()]  # tolerate blank lines
    values = np.array(rows, dtype = float).T
    return {i:v for i,v in zip(items, values)}


def _parse_amber_out(lines):
    '''Parse the per-frame "name = value" records of the "4.  RESULTS" section.'''
    separator = '-' * 78  # token that terminates every frame record

    start = lines.index('   4.  RESULTS\n') + 2
    stop = lines.index('      R M S  F L U C T U A T I O N S\n')
    # Replace '=' by blanks and split so that names and values become separate tokens.
    info = [tok for line in lines[start:stop] for tok in line.replace('=', ' ').split()]
    frame_num = info.count(separator)  # Frame num + 1 for the appended average frame

    # Item names: the non-numeric tokens of the first frame.
    frame_0 = info[:info.index(separator)]
    items = [tok for tok in frame_0 if not _is_number(tok)]
    # Drop up to two '1-4' prefixes so that e.g. '1-4 NB' collapses to 'NB'.
    for _ in range(2):
        try:
            items.remove('1-4')
        except ValueError:
            break

    # Values: every numeric token, in file order.
    values = [tok for tok in info if _is_number(tok)]

    item_num = len(values)//frame_num
    # Remove the step count X in 'A V E R A G E S   O V E R    X S T E P S' and the data behind it
    values = values[:item_num*(frame_num-1)]
    values = np.array(values, dtype = float).reshape(frame_num - 1, -1).T  # Remove the average frame
    return {k:v for k,v in zip(items[:item_num], values)}


def output_reader(fmt, filename, atom_num = 1):
    '''
    Read a simulation / quantum-chemistry output file into a dict.

    fmt:      'gout' (Gaussian), 'pout' (pyscf), 'sout' (sponge) or 'aout'
    filename: path of the text file to read
    atom_num: number of gradient rows to read; only used by 'pout'

    Returns
      'gout': {'potential': energy * HAR_TO_KCAL, 'force': list of split text rows}
      'pout': {'potential': energy * HAR_TO_KCAL, 'force': (atom_num, 3) array * HARBOHR_TO_KCALNM}
      'sout': {column name: 1-D array of that column}
      'aout': {item name: 1-D array over frames}, the trailing average frame excluded

    Raises ValueError for an unknown `fmt` or when a required energy line is missing.
    '''
    if fmt == 'gout':
        return _parse_gaussian_out(_read_lines(filename))
    if fmt == 'pout':
        return _parse_pyscf_out(_read_lines(filename), atom_num)
    if fmt == 'sout':
        return _parse_sponge_out(_read_lines(filename))
    if fmt == 'aout':
        return _parse_amber_out(_read_lines(filename))
    raise ValueError('fmt should be gout, pout ,aout or sout, but you set:', fmt)

                    
def exchange_atomic_num_label(mode, atomic_info: np.ndarray):
    '''
    Convert between atomic numbers and element labels, element-wise.

    mode:        'ntol' maps atomic numbers to labels (6 -> 'C'),
                 'lton' maps labels to atomic numbers ('C' -> 6)
    atomic_info: array-like of atomic numbers (ntol) or labels (lton)

    Returns an array with the same shape as `atomic_info`. Only H, C, N, O, F,
    Na, Mg, P, S and Cl are known; any other entry is looked up with dict.get
    and therefore yields None before NumPy converts the result.

    Raises ValueError when `mode` is neither 'ntol' nor 'lton'.
    '''
    num_to_label = {1:'H', 6: 'C', 7:'N', 8:'O', 9:'F', 11:'Na', 12:'Mg', 15:'P', 16:'S', 17: 'Cl'}
    label_to_num = {v:k for k,v in num_to_label.items()}
    if mode == 'ntol':
        table, empty_dtype = num_to_label, str
    elif mode == 'lton':
        table, empty_dtype = label_to_num, int
    else:
        raise ValueError("mode should only be ntol or lton")

    atomic_info = np.asarray(atomic_info)
    if atomic_info.size == 0:
        # np.vectorize cannot infer an output type from size-0 input and raises.
        return np.empty(atomic_info.shape, dtype = empty_dtype)
    return np.vectorize(table.get)(atomic_info)


def dat_reader(filename, frame_num, atom_num):
    '''
    Read a raw binary trajectory of native-endian 32-bit floats.

    filename:  path of the binary file
    frame_num: number of frames to read from the start of the file
    atom_num:  number of atoms per frame

    Returns a writable float32 array of shape (frame_num, atom_num, 3). Bytes
    beyond the requested frames are ignored.

    Raises ValueError when the file holds fewer than frame_num*atom_num*3 floats.
    '''
    byte_count = frame_num*atom_num*3*4  # 3 components, 4 bytes per float32
    with open(filename, 'rb') as f:
        raw = f.read(byte_count)
    if len(raw) != byte_count:
        raise ValueError(
            f'{filename} holds {len(raw)} bytes but {byte_count} are needed '
            f'for {frame_num} frames of {atom_num} atoms'
        )
    # frombuffer reinterprets the bytes directly instead of unpacking them into
    # a tuple of Python floats; copy() makes the result writable.
    coordinate = np.frombuffer(raw, dtype = np.float32).copy()
    return coordinate.reshape(frame_num, atom_num, 3)


def h5md_reader(filename):
    '''
    Load the 'trajectory' observables and particle data of an HDF5 file into memory.

    filename: path of the HDF5 file, opened read-only

    Returns a dict of NumPy arrays with the keys 'total', 'kinetic',
    'potential', 'temperature', 'energies' (from observables/trajectory) and
    'traj', 'image', 'velocity', 'box', 'force' (from particles/trajectory).
    Each array is the 'value' dataset of the corresponding group.
    '''
    def load(group, name):
        # np.array copies the dataset into memory, so it outlives the file handle.
        return np.array(group[name]['value'])

    # The context manager closes the file even when a key is missing.
    with hp.File(filename, 'r') as simuinfo:
        observables = simuinfo['observables']['trajectory']
        particles = simuinfo['particles']['trajectory']

        # The original looked up the misspelled group 'foece'. Prefer 'force'
        # and fall back to the old spelling for files that really use it.
        force_key = 'force' if 'force' in particles else 'foece'

        return {
            'total': load(observables, 'total_energy'),
            'kinetic': load(observables, 'kinetic_energy'),
            'potential': load(observables, 'potential_energy'),
            'temperature': load(observables, 'temperature'),
            'energies': load(observables, 'energies'),
            'traj': load(particles, 'position'),
            'image': load(particles, 'image'),
            'velocity': load(particles, 'velocity'),
            'box': load(particles, 'box'),
            'force': load(particles, force_key),
        }


def nc_reader(filename):
    '''
    Load time, coordinates and (when stored) forces from a NetCDF trajectory.

    filename: path of the NetCDF file, opened read-only

    Returns {'time': array, 'traj': array, 'force': array or None}.
    'force' is None when the file has neither a 'force' nor a 'forces' variable.
    '''
    # The context manager closes the dataset once the arrays are in memory.
    with nc.Dataset(filename, 'r') as dataset:
        variables = dataset.variables
        time = np.array(variables['time'])
        traj = np.array(variables['coordinates'])

        # The original returned np.array(['force']), i.e. the literal string
        # rather than data. Read the variable instead; 'forces' is the name
        # used by the AMBER NetCDF trajectory convention.
        force = None
        for key in ('force', 'forces'):
            if key in variables:
                force = np.array(variables[key])
                break
    return {'time': time, 'traj': traj, 'force': force}


def reweighting(U_bias: np.ndarray, T: float):
    '''
    Return the reweighting factors exp(beta * (U_bias - mean(U_bias))).

    U_bias: array-like of bias energies. Unit: kj/mol
    T:      temperature; beta is computed as R_inv / T

    The mean is subtracted before exponentiation, so the factors are only
    defined up to a common multiplicative constant.
    '''
    U_bias = np.asarray(U_bias)
    centered = U_bias - U_bias.mean()
    beta = R_inv/T
    return np.exp(centered*beta)

def free_energy(distribution: np.ndarray, T):
    '''
    Return the free energy A = -ln(distribution) * R * T.

    distribution: array-like of (positive) probabilities or densities
    T:            temperature

    Unit: kcal/mol
    '''
    # R is 8.314/1000, so T*R is in kJ/mol; 1 kcal = 4.184 kJ.
    # The previous divisor 4.12 made every value about 1.5 % too large.
    kj_per_kcal = 4.184
    beta_inv = T*R/kj_per_kcal
    A = -np.log(distribution)*beta_inv
    return A




def trainset_concate(neo_trainset_name, neo_validset_name, trainset_sample_number, set_list, shuffle = True):
    '''
    Merge several datasets, re-normalise them and save a train / validation split.

    neo_trainset_name:      output path of the training set (.npz)
    neo_validset_name:      output path of the validation set (.npz)
    trainset_sample_number: number of leading samples written to the training
                            set; the remaining samples form the validation set
    set_list:               sequence of mappings with the keys 'R', 'F', 'E',
                            'Z', 'natom', 'scale' and 'shift'
    shuffle:                unless equal to False, samples are permuted with a
                            fixed seed (114514) before the split

    Each input set is first de-normalised (F*scale, E*scale + shift). The merged
    energies are then standardised with their own mean ('shift') and standard
    deviation ('scale'), and the forces are divided by the same scale. 'Z' and
    'natom' are taken from the first set.

    Raises ValueError when the arrays cannot be concatenated along axis 0.
    '''
    crd_list = [set_i['R'] for set_i in set_list]
    frc_list = [set_i['F']*set_i['scale'] for set_i in set_list]
    potential = [set_i['E']*set_i['scale'] + set_i['shift'] for set_i in set_list]
    try:
        crd = np.concatenate(crd_list, axis = 0)
        frc = np.concatenate(frc_list, axis = 0)
        E = np.concatenate(potential, axis = 0)
    except (ValueError, TypeError) as err:
        # Chain the original error so the offending shapes stay visible.
        raise ValueError("Please check the dimension of the input") from err

    shift = np.mean(E).reshape(1)
    scale = np.std(E).reshape(1)
    avg_force_dis = np.array(np.linalg.norm(frc, axis = -1).mean()).reshape(-1)

    frc = frc/scale
    E = (E - shift)/scale

    # Kept as an equality test so that every value other than False/0
    # (including None) still shuffles, exactly as before.
    if not (shuffle == False):
        # A private RandomState seeded like the former np.random.seed call gives
        # the same permutation without resetting the global NumPy RNG.
        rng = np.random.RandomState(114514)
        shuffle_idx = np.arange(E.shape[0])
        rng.shuffle(shuffle_idx)
        crd, frc, E = crd[shuffle_idx], frc[shuffle_idx], E[shuffle_idx]

    # Metadata shared by both output files.
    common = dict(
        Z = set_list[0]['Z'],
        natom = set_list[0]['natom'],
        shift = shift,
        scale = scale,
        avg_force_dis = avg_force_dis,
    )
    split = trainset_sample_number
    np.savez(neo_trainset_name, R = crd[:split], F = frc[:split], E = E[:split], **common)
    np.savez(neo_validset_name, R = crd[split:], F = frc[split:], E = E[split:], **common)





